package com.asggo.g21.utils;

import static com.asggo.g21.constant.CacheConstants.SSE_EMITTER_TOPIC;

import cn.hutool.core.collection.CollUtil;
import com.asggo.g21.dto.SseMessageDTO;
import com.asggo.g21.ex.SseException;
import com.asggo.g21.holder.SseEmitterHolder;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.springframework.web.socket.PongMessage;

/**
 * Static helpers for Server-Sent Events (SSE) connections.
 *
 * <p>Each connection is an {@link SseEmitter} registered in {@link SseEmitterHolder} under a
 * user/session key and kept alive by a periodic heartbeat. Messages addressed to keys held by
 * this service instance are written directly; keys not found locally are published on
 * {@code SSE_EMITTER_TOPIC} through {@code RedisUtils}, for delivery by a subscriber registered
 * with {@link #subscribeMessage(Consumer)}.
 *
 * @author eric 2024/3/27 11:22
 */
@Slf4j
public final class SseEmitterUtil {

  /** Number of threads shared by the heartbeats of all connections. */
  private static final int HEARTBEAT_POOL_SIZE = 8;

  /** Seconds between two heartbeats on the same connection. */
  private static final long HEARTBEAT_INTERVAL_SECONDS = 10L;

  /**
   * Scheduler that sends heartbeats. Its threads are daemon threads so this static pool cannot
   * keep the JVM alive on its own after the application has shut down.
   */
  private static final ScheduledExecutorService HEARTBEAT_EXECUTOR =
      newHeartbeatExecutor(HEARTBEAT_POOL_SIZE);

  private SseEmitterUtil() {
  }

  /**
   * Creates a never-expiring SSE connection for a user, registers it and starts its heartbeat.
   *
   * <p>When the connection completes, times out or fails, its heartbeat is cancelled and its
   * holder entry is removed. The entry is removed only if it still refers to this emitter, so a
   * newer connection under the same key is left untouched.
   *
   * @param userId user ID, used as the key in {@link SseEmitterHolder}
   * @return the emitter to hand back to Spring MVC
   */
  public static SseEmitter connect(String userId) {
    // 0 means "never time out". Otherwise the default async timeout (30 seconds) applies and an
    // unfinished stream ends with AsyncRequestTimeoutException.
    SseEmitter sseEmitter = new SseEmitter(0L);

    // Scheduled before the callbacks are registered so that they can cancel it.
    ScheduledFuture<?> heartbeat = HEARTBEAT_EXECUTOR.scheduleAtFixedRate(
        heartbeatTask(userId, sseEmitter), 0L, HEARTBEAT_INTERVAL_SECONDS, TimeUnit.SECONDS);

    sseEmitter.onCompletion(completionCallBack(userId, sseEmitter, heartbeat));
    sseEmitter.onError(errorCallBack(userId, sseEmitter, heartbeat));
    sseEmitter.onTimeout(timeoutCallBack(userId, sseEmitter, heartbeat));

    SseEmitterHolder.addSseEmitter(userId, sseEmitter);

    log.info("创建新的sse连接，当前用户：{}", userId);
    return sseEmitter;
  }

  /**
   * Removes the connection registered under {@code userId} from {@link SseEmitterHolder}.
   *
   * <p>This only delegates to {@link SseEmitterHolder#removeSseEmitter}. Whether that also
   * completes the emitter is up to the holder.
   *
   * @param userId user/session key to remove
   */
  public static void removeUser(String userId) {
    SseEmitterHolder.removeSseEmitter(userId);

    log.info("移除用户：{}", userId);
  }

  /**
   * Subscribes to SSE messages published on {@code SSE_EMITTER_TOPIC}.
   *
   * @param consumer custom handler invoked with each received message
   */
  public static void subscribeMessage(Consumer<SseMessageDTO> consumer) {
    RedisUtils.subscribe(SSE_EMITTER_TOPIC, SseMessageDTO.class, consumer);
  }

  /**
   * Delivers a message to the given sessions.
   *
   * <p>Sessions held by this service instance are written to directly. All remaining session
   * keys are published together on {@code SSE_EMITTER_TOPIC} so another instance can deliver
   * them.
   *
   * @param sseMessage message plus target session keys; {@code getSessionKeys()} must not be
   *     {@code null} (use {@link #publishAll(String)} to broadcast)
   */
  public static void publishMessage(SseMessageDTO sseMessage) {
    List<String> unsentSessionKeys = new ArrayList<>();
    for (String sessionKey : sseMessage.getSessionKeys()) {
      if (SseEmitterHolder.existSession(sessionKey)) {
        // Session lives in this instance: send directly.
        sendMessage(sessionKey, sseMessage.getMessage());
      } else {
        unsentSessionKeys.add(sessionKey);
      }
    }
    if (CollUtil.isEmpty(unsentSessionKeys)) {
      return;
    }
    // Sessions not held by this instance: relay through the topic.
    SseMessageDTO broadcastMessage = new SseMessageDTO();
    broadcastMessage.setMessage(sseMessage.getMessage());
    broadcastMessage.setSessionKeys(unsentSessionKeys);
    RedisUtils.publish(
        SSE_EMITTER_TOPIC,
        broadcastMessage,
        ignored ->
            log.info(" SSE 发送主题订阅消息topic:{} session keys:{} message:{}",
                SSE_EMITTER_TOPIC, unsentSessionKeys, sseMessage.getMessage()
            )
    );
  }

  /**
   * Publishes a message on {@code SSE_EMITTER_TOPIC} without session keys (broadcast to all).
   *
   * @param message message content
   */
  public static void publishAll(String message) {
    SseMessageDTO broadcastMessage = new SseMessageDTO();
    broadcastMessage.setMessage(message);
    RedisUtils.publish(
        SSE_EMITTER_TOPIC,
        broadcastMessage,
        ignored
            -> log.info("sse发送主题订阅消息topic:{} message:{}", SSE_EMITTER_TOPIC, message)
    );
  }

  /**
   * Sends {@code message} to the connection registered under {@code sessionKey}, if any.
   *
   * <p>On an I/O failure the connection is dropped from the holder.
   *
   * @param sessionKey user/session key
   * @param message    message content
   */
  public static void sendMessage(String sessionKey, String message) {
    final SseEmitter sseEmitter = SseEmitterHolder.getSseEmitter(sessionKey);
    if (sseEmitter == null) {
      return;
    }
    try {
      sseEmitter.send(message);
    } catch (IOException e) {
      log.error("用户[{}]推送异常:{}", sessionKey, e.getMessage());
      removeUserIfCurrent(sessionKey, sseEmitter);
    }
  }

  /**
   * Creates the heartbeat scheduler with daemon threads named {@code sse-heartbeat-N}.
   */
  private static ScheduledExecutorService newHeartbeatExecutor(int poolSize) {
    AtomicInteger threadCounter = new AtomicInteger();
    return Executors.newScheduledThreadPool(poolSize, runnable -> {
      Thread thread = new Thread(runnable, "sse-heartbeat-" + threadCounter.incrementAndGet());
      thread.setDaemon(true);
      return thread;
    });
  }

  /**
   * Builds the periodic heartbeat for one connection.
   *
   * <p>On an I/O failure the connection is dropped from the holder, and the task throws. Per
   * {@link ScheduledExecutorService#scheduleAtFixedRate}, an exception thrown by the task
   * suppresses all of its later executions, so a dead connection is no longer pinged.
   *
   * <p>NOTE(review): the payload is a WebSocket {@link PongMessage} serialized through the
   * emitter's message converters. An SSE comment event is the usual keep-alive, but the payload
   * is kept unchanged here so existing clients receive the same data.
   */
  private static Runnable heartbeatTask(String userId, SseEmitter sseEmitter) {
    return () -> {
      try {
        sseEmitter.send(new PongMessage());
      } catch (IOException e) {
        log.error("心跳发送异常 -{}", ExceptionUtils.getStackTrace(e));
        removeUserIfCurrent(userId, sseEmitter);
        throw new SseException();
      }
    };
  }

  /**
   * Removes {@code userId} only if the holder still maps it to {@code sseEmitter}. This keeps a
   * stale connection's cleanup from evicting a newer connection registered under the same key.
   * It assumes the holder keeps one emitter per key, as {@code getSseEmitter} suggests.
   */
  private static void removeUserIfCurrent(String userId, SseEmitter sseEmitter) {
    if (SseEmitterHolder.getSseEmitter(userId) == sseEmitter) {
      removeUser(userId);
    }
  }

  /**
   * Releases a finished connection: stops its heartbeat and drops it from the holder.
   */
  private static void release(String userId, SseEmitter sseEmitter,
      ScheduledFuture<?> heartbeat) {
    // false: do not interrupt a heartbeat that is being written right now.
    heartbeat.cancel(false);
    removeUserIfCurrent(userId, sseEmitter);
  }

  private static Runnable completionCallBack(String userId, SseEmitter sseEmitter,
      ScheduledFuture<?> heartbeat) {
    return () -> {
      log.info("结束连接：{}", userId);
      release(userId, sseEmitter, heartbeat);
    };
  }

  private static Runnable timeoutCallBack(String userId, SseEmitter sseEmitter,
      ScheduledFuture<?> heartbeat) {
    return () -> {
      log.info("连接超时：{}", userId);
      release(userId, sseEmitter, heartbeat);
    };
  }

  private static Consumer<Throwable> errorCallBack(String userId, SseEmitter sseEmitter,
      ScheduledFuture<?> heartbeat) {
    return throwable -> {
      // The trailing Throwable is logged by SLF4J as the exception instead of being discarded.
      log.info("连接异常：{}", userId, throwable);
      release(userId, sseEmitter, heartbeat);
    };
  }
}
